{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0SYLNaw4ZkG"
      },
      "source": [
        "# Text Generation\n",
        "\n",
        "In this notebook, we train an adapter for **GPT-2** that performs **poem generation**. We use a dataset of poem verses extracted from Project Gutenberg that is [available via HuggingFace datasets](https://huggingface.co/datasets/poem_sentiment).\n",
        "\n",
        "First, let's install all required libraries:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T12:15:51.447758300Z",
          "start_time": "2023-08-17T12:15:46.491925400Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "91hZktTHnvz6",
        "outputId": "c475a92a-4f1f-4e26-a3f4-31f23670c35d"
      },
      "outputs": [],
      "source": [
        "!pip install -Uq adapters\n",
        "!pip install -q datasets\n",
        "!pip install -Uq accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fe_W__U6uFBs"
      },
      "source": [
        "Next, we need to download the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:27:10.813091500Z",
          "start_time": "2023-08-17T10:27:02.486414700Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zBuzA0h3n4-z",
        "outputId": "b7fe4eab-0e31-4790-c233-6745268be1bd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['id', 'verse_text', 'label'],\n",
            "        num_rows: 892\n",
            "    })\n",
            "    validation: Dataset({\n",
            "        features: ['id', 'verse_text', 'label'],\n",
            "        num_rows: 105\n",
            "    })\n",
            "    test: Dataset({\n",
            "        features: ['id', 'verse_text', 'label'],\n",
            "        num_rows: 104\n",
            "    })\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"poem_sentiment\")\n",
        "print(dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "neZt29NKuT11"
      },
      "source": [
        "Before training, we need to preprocess the dataset. We tokenize the entries in the dataset and remove all columns we don't need to train the adapter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:27:28.269312400Z",
          "start_time": "2023-08-17T10:27:22.403182100Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 81,
          "referenced_widgets": [
            "d3fdcb28de1441c0864efa2e8f8116b9",
            "e471cf32f4024a5c877ac2b90afcdeff",
            "e77c608bd16a4681965f5c3565f397d9",
            "ace32f0902294633b0ac7e68dcdfd070",
            "3d2cafaf370d477fb1438a37a6cffedb",
            "7eff96f1a4d24080925d79d8d762bb00",
            "61d803c90d75495eb74f27d29fdfa8fa",
            "9aa5297615e9473cae0819c258bb3647",
            "4f6da33446e24bcb95af0c2ad88a0c2a",
            "32bf3939e64b49378a5084104add9956",
            "a9680458297944de83c90a0fb8f6bf94",
            "ba24329f322244649c532e1146ce7297",
            "ccdb6919534b45e2af8232d13cb6615d",
            "ac124b9e039a4bacadc3e45427d5f191",
            "c85416d9a34547a189bb84c07b90c377",
            "2d3f6f8c280d4ff489f2e5369fa9cae7",
            "865c3c1d46ad4e1cb6e72b011e9bc976",
            "ef94f58183ba4278834424e4f7eeeda4",
            "f26a50a35dad4710a871daa96c2048bd",
            "ee7d2806988f48d4b949f8f058823ca6",
            "39a2033f7857472abbb826884781cdf3",
            "7839557a4a9e46ef95f9e7aa2ce2016f"
          ]
        },
        "id": "d3Lz6h9Zo9t0",
        "outputId": "9be84939-0df3-416e-aef1-615a1958f23a"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d3fdcb28de1441c0864efa2e8f8116b9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/105 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ba24329f322244649c532e1146ce7297",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/104 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from transformers import GPT2Tokenizer\n",
        "\n",
        "def encode_batch(batch):\n",
        "  \"\"\"Encodes a batch of input data using the model tokenizer.\"\"\"\n",
        "  encoding = tokenizer(batch[\"verse_text\"])\n",
        "  # The labels for language modeling are copied from input_ids later,\n",
        "  # in group_texts, once the token ids have been split into blocks\n",
        "  return encoding\n",
        "\n",
        "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
        "# The GPT-2 tokenizer does not have a padding token. In order to process the data\n",
        "# in batches we set one here\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "column_names = dataset[\"train\"].column_names\n",
        "dataset = dataset.map(encode_batch, remove_columns=column_names, batched=True)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jqJWsH5_ENkb"
      },
      "source": [
        "Next, we concatenate the tokenized verses and split them into chunks of `block_size` tokens. Every training example then has the same length, so batches need no padding."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:27:42.403373100Z",
          "start_time": "2023-08-17T10:27:38.946111800Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 81,
          "referenced_widgets": [
            "19f8d13ac27b490ebbf548d7ac17ee35",
            "271a1b8bb08f4c80aec6e42724888071",
            "a760bdbd3fa24f0fadf43b7a90a9de15",
            "4f3fa27889184f13b4d9381f3d4fa36a",
            "5edfd841e3304894b6292627232ad20b",
            "bd72866bc2fb4d6db3db562cc7ce0b88",
            "5a584b69dced44f6b6c07be33bb50b28",
            "6ca9de574285462f92533311bf7dbec9",
            "fdc20d07dce446b0a31dfdb94e8b85d4",
            "ed35b967172449a8b692cd4488b6008e",
            "5c9bffdbc1f74185adac4ff7c2dd7eb8",
            "1618a0043b9d484095b23236256d4ea3",
            "d571f62a63624716a33a57131823ba35",
            "ff86f0b512ca4ff4b9ca3c5b4ebeb130",
            "2bbb05954ab54e90aa2ee2f04537f1d3",
            "7a766236aef949cda4ee1dde37b74e7f",
            "a2ce09a2d2084db890375ad52281d1de",
            "9793a6e3de424a2398abdd432def5068",
            "a3e4dd5032cb46cb93b0ead50621f636",
            "3e186d5f6ef44d79b6d26532138e92e7",
            "599eded89878424795773f9335588ea4",
            "6d6bee4b4a2a49ddbaa130e4c34af154"
          ]
        },
        "id": "3ff5lRmnA6R7",
        "outputId": "08e4c813-3c4e-4c90-b51f-762816d7013b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "19f8d13ac27b490ebbf548d7ac17ee35",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/105 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1618a0043b9d484095b23236256d4ea3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/104 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "block_size = 50\n",
        "# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n",
        "def group_texts(examples):\n",
        "  # Concatenate all texts.\n",
        "  concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
        "  total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
        "  # Drop the small remainder that does not fill a whole block. Padding the last\n",
        "  # block would be an alternative; customize this part to your needs.\n",
        "  total_length = (total_length // block_size) * block_size\n",
        "  # Split into chunks of block_size.\n",
        "  result = {\n",
        "    k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
        "    for k, t in concatenated_examples.items()\n",
        "  }\n",
        "  result[\"labels\"] = result[\"input_ids\"].copy()\n",
        "  return result\n",
        "\n",
        "dataset = dataset.map(group_texts, batched=True)\n",
        "\n",
        "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AT56RfOGFXvt"
      },
      "source": [
        "Next, we create the model and add our new adapter. Because we create the model with the `AutoModelForCausalLM` class from the `transformers` package and not directly from `adapters`, we first need to enable adapter support by calling `adapters.init()` on it. Then we add the adapter; we call it `poem`, since it is trained to create new poems. Finally, we activate it and prepare it for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:32:20.482874100Z",
          "start_time": "2023-08-17T10:32:17.685600500Z"
        },
        "id": "ioLpFbOfnPE6"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoModelForCausalLM\n",
        "from adapters import init\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
        "# Enable adapter support\n",
        "init(model)\n",
        "# Add new adapter\n",
        "model.add_adapter(\"poem\")\n",
        "# Activate adapter for training\n",
        "model.train_adapter(\"poem\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkCUz8B6Fw5E"
      },
      "source": [
        "The last thing we need to do before we can start training is to create the trainer. As training arguments, we choose a learning rate of 5e-4 and train for 3 epochs. Feel free to play around with the parameters and see how they affect the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:32:44.036010200Z",
          "start_time": "2023-08-17T10:32:43.627871Z"
        },
        "id": "stDwnIEApmNu"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainingArguments\n",
        "from adapters import AdapterTrainer\n",
        "training_args = TrainingArguments(\n",
        "  output_dir=\"./examples\",\n",
        "  do_train=True,\n",
        "  remove_unused_columns=False,\n",
        "  learning_rate=5e-4,\n",
        "  num_train_epochs=3,\n",
        ")\n",
        "\n",
        "trainer = AdapterTrainer(\n",
        "  model=model,\n",
        "  args=training_args,\n",
        "  tokenizer=tokenizer,\n",
        "  train_dataset=dataset[\"train\"],\n",
        "  eval_dataset=dataset[\"validation\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 165
        },
        "id": "Mtjczmcxqv-l",
        "outputId": "73c7d317-ed13-41f3-e284-278e71e1fb97"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='66' max='66' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [66/66 00:07, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=66, training_loss=5.651855931137547, metrics={'train_runtime': 11.5993, 'train_samples_per_second': 45.52, 'train_steps_per_second': 5.69, 'total_flos': 13614563635200.0, 'train_loss': 5.651855931137547, 'epoch': 3.0})"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3pkNcsRzGCQo"
      },
      "source": [
        "Now that we have a trained adapter, we save it for future use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:32:58.398931900Z",
          "start_time": "2023-08-17T10:32:57.257361300Z"
        },
        "id": "_Q_pKmUkqx7P"
      },
      "outputs": [],
      "source": [
        "model.save_adapter(\"adapter_poem\", \"poem\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjcY1WnaGIxi"
      },
      "source": [
        "Next, let's generate some poetry with our trained adapter. To do this, we create a `GPT2LMHeadModel`, which is well suited for language generation, and load our trained adapter into it. Finally, we have to choose the start of our poem. If you want your poem to start differently, just change `PREFIX` accordingly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:34:02.878723300Z",
          "start_time": "2023-08-17T10:34:00.038412300Z"
        },
        "id": "m4TML7t2mZRp"
      },
      "outputs": [],
      "source": [
        "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
        "\n",
        "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
        "# Enable adapter support\n",
        "init(model)\n",
        "# Load the locally trained adapter and activate it\n",
        "model.load_adapter(\"adapter_poem\")\n",
        "model.set_active_adapters(\"poem\")\n",
        "\n",
        "PREFIX = \"In the night\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I0tFdm1RGtma"
      },
      "source": [
        "For the generation, we need to tokenize the prefix first and then pass it to the model. In this case, we create five possible continuations for the beginning we chose."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:34:07.409085900Z",
          "start_time": "2023-08-17T10:34:04.455352700Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h6iiEOwfsdVd",
        "outputId": "f842cabc-7091-4323-b466-93b2fa0159d6"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        }
      ],
      "source": [
        "encoding = tokenizer(PREFIX, return_tensors=\"pt\")\n",
        "output_sequence = model.generate(\n",
        "  input_ids=encoding[\"input_ids\"],\n",
        "  attention_mask=encoding[\"attention_mask\"],\n",
        "  do_sample=True,\n",
        "  num_return_sequences=5,\n",
        "  max_length=50,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SZ4uCZaZG68k"
      },
      "source": [
        "Lastly, we want to see what the model actually created. To do this, we decode the token ids back into text and cut each text off at the first EOS token. You can easily use this code with another dataset. Don't forget to share your adapters at [AdapterHub](https://adapterhub.ml/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:34:11.515110800Z",
          "start_time": "2023-08-17T10:34:11.489394600Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TX3QpuWPtrol",
        "outputId": "aab36759-5ce0-41af-d1a2-fcc59cdab006",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "=== GENERATED SEQUENCE 1 ===\n",
            "In the night,when it's still,how shall he sing,the king's favorite music.\n",
            "\n",
            "for those of us who live in the old cities to-day:\n",
            "\n",
            "all the streets and the rivers,the rivers that ran\n",
            "=== GENERATED SEQUENCE 2 ===\n",
            "In the night,he died at the barroom on the morning,a silent man, a widow who had borne the ill son of his trolls.the son was long standing and the son remembered him for his service.even the wisest ma\n",
            "=== GENERATED SEQUENCE 3 ===\n",
            "In the night that night, a man left a stone in his hand,and in his face he is seen--as he looked round him he could not see,and when he did,the two men walked,he turned his head with one though\n",
            "=== GENERATED SEQUENCE 4 ===\n",
            "In the night, the sun had a bright ray to dazzle upon earth.and how, though they call his name, is he like the angel of hell upon the world?when he, who lives in his dream, says,--what i\n",
            "=== GENERATED SEQUENCE 5 ===\n",
            "In the night!not from the earth's gates!the sweet scent of sweet-day.\"they said,what thou dost say!\"as thou wilt read,there lied,and yet,when she sleepsin' a night o\n"
          ]
        }
      ],
      "source": [
        "for generated_sequence_idx, generated_sequence in enumerate(output_sequence):\n",
        "  print(\"=== GENERATED SEQUENCE {} ===\".format(generated_sequence_idx + 1))\n",
        "  generated_sequence = generated_sequence.tolist()\n",
        "\n",
        "  # Decode the token ids back to text\n",
        "  text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)\n",
        "  # Keep only the text before the first end-of-text token; if none was\n",
        "  # generated, the text is kept whole\n",
        "  text = text.split(tokenizer.eos_token)[0]\n",
        "\n",
        "  print(text)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.4"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
